1 Introduction


Objectives: The goal of this kernel is to find the best approach to predicting the quality of a wine. We will go through basic EDA and visually inspect quality via a 3D interactive plot. I then apply several ML models to make the prediction, each of which has its own strengths:

rpart gives a nice decision tree plot, so you can inspect the variables more intuitively.

Random Forest is the model you can usually run directly with a minimum amount of tuning.

xgboost is expected to produce the best result, but needs a bit of tuning.

svm is an alternative approach and usually gives results that are less correlated with those of the tree-based models.

h2o's deeplearning is one of the easiest tools for applying a deep learning model. I could potentially use keras, but given the size and structure of the data, I don't believe a deep learning model would outperform xgboost in this case.

Confusion Matrix is used to evaluate the results.

If you have any questions, please leave a comment, and if you like the kernel, please give me an upvote~ Thanks!


2 Basic Set up



2.1 Load Packages


if (!require("pacman")) install.packages("pacman")
pacman::p_load(tidyverse, skimr, knitr, GGally, plotly, viridis, caret, randomForest, e1071, rpart, xgboost, h2o, corrplot)
#pacman::p_load(tidyverse, skimr, GGally, corrplot, plotly, viridis, caret, randomForest, e1071, rpart, rattle, xgboost, h2o)

2.2 Load Dataset


wine <- read_csv("winequality-red.csv")

3 EDA



3.1 First Glimpse via skim


wine %>% skim() %>% kable()
## Skim summary statistics  
##  n obs: 1599    
##  n variables: 12    
## 
## Variable type: integer
## 
## variable   missing   complete   n      mean   sd     p0   p25   p50   p75   p100   hist     
## ---------  --------  ---------  -----  -----  -----  ---  ----  ----  ----  -----  ---------
## quality    0         1599       1599   5.64   0.81   3    5     6     6     8      ▁▁▁▇▇▁▂▁
## 
## Variable type: numeric
## 
## variable               missing   complete   n      mean    sd       p0      p25    p50     p75    p100   hist     
## ---------------------  --------  ---------  -----  ------  -------  ------  -----  ------  -----  -----  ---------
## alcohol                0         1599       1599   10.42   1.07     8.4     9.5    10.2    11.1   14.9   ▂▇▅▃▂▁▁▁
## chlorides              0         1599       1599   0.087   0.047    0.012   0.07   0.079   0.09   0.61   ▇▃▁▁▁▁▁▁
## citric acid            0         1599       1599   0.27    0.19     0       0.09   0.26    0.42   1      ▇▅▅▆▂▁▁▁
## density                0         1599       1599   1       0.0019   0.99    1      1       1      1      ▁▁▃▇▇▂▁▁
## fixed acidity          0         1599       1599   8.32    1.74     4.6     7.1    7.9     9.2    15.9   ▁▇▇▅▂▁▁▁
## free sulfur dioxide    0         1599       1599   15.87   10.46    1       7      14      21     72     ▇▇▅▂▁▁▁▁
## pH                     0         1599       1599   3.31    0.15     2.74    3.21   3.31    3.4    4.01   ▁▁▅▇▅▁▁▁
## residual sugar         0         1599       1599   2.54    1.41     0.9     1.9    2.2     2.6    15.5   ▇▂▁▁▁▁▁▁
## sulphates              0         1599       1599   0.66    0.17     0.33    0.55   0.62    0.73   2      ▂▇▂▁▁▁▁▁
## total sulfur dioxide   0         1599       1599   46.47   32.9     6       22     38      62     289    ▇▅▂▁▁▁▁▁
## volatile acidity       0         1599       1599   0.53    0.18     0.12    0.39   0.52    0.64   1.58   ▂▇▇▃▁▁▁▁

3.2 Second Glimpse via Corrplot


wine %>% cor() %>% corrplot.mixed(upper = "ellipse", tl.cex=.8, tl.pos = 'lt', number.cex = .8)


4 Preprocess


colnames(wine) <- wine %>% colnames() %>% str_replace_all(" ","_")
wine$quality <- as.factor(wine$quality)
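Before modelling, it is worth checking how unbalanced the target is; grades 5 and 6 dominate while 3, 4, and 8 are rare, which is why the rare classes are hard for every model below. A quick base-R sketch:

```r
# How unbalanced is the target? The counts and proportions per grade
# show that the middle grades (5 and 6) account for most of the data.
table(wine$quality)
round(prop.table(table(wine$quality)), 3)
```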

5 GGally - ggpairs


I had a quick look and found that the following variables: residual_sugar, free_sulfur_dioxide, total_sulfur_dioxide, and chlorides show no significant difference across quality levels. Therefore, these variables are not included in the ggpairs plot. Furthermore, I found that volatile_acidity, sulphates, and alcohol show more significant differences across quality levels, based on the graph below.

wine %>% 
  mutate(quality = as.factor(quality)) %>% 
  select(-c(residual_sugar, free_sulfur_dioxide, total_sulfur_dioxide, chlorides)) %>% 
  ggpairs(aes(color = quality,alpha=0.4),
          columns=1:7,
          lower=list(continuous="points"),
          upper=list(continuous="blank"),
          axisLabels="none", switch="both")


6 Plotly 3D Interactive Graph


wine %>% 
  plot_ly(x=~alcohol,y=~volatile_acidity,z= ~sulphates, color=~quality, hoverinfo = 'text', colors = viridis(3),
          text = ~paste('Quality:', quality,
                        '<br>Alcohol:', alcohol,
                        '<br>Volatile Acidity:', volatile_acidity,
                        '<br>sulphates:', sulphates)) %>% 
  add_markers(opacity = 0.8) %>%
  layout(title = "3D Wine Quality",
         annotations=list(yref='paper',xref="paper",y=1.05,x=1.1, text="quality",showarrow=F),
         scene = list(xaxis = list(title = 'Alcohol'),
                      yaxis = list(title = 'Volatile Acidity'),
                      zaxis = list(title = 'sulphates')))

7 Cross Validation Setup


set.seed(1)
inTrain <- createDataPartition(wine$quality, p=.9, list = F)

train <- wine[inTrain,]
valid <- wine[-inTrain,]
rm(inTrain)
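createDataPartition() samples within each level of quality, so the split above is stratified. A quick sanity-check sketch compares the class proportions in the two sets, which should be nearly identical:

```r
# Stratified split check: per-grade proportions in train vs. validation.
round(prop.table(table(train$quality)), 3)
round(prop.table(table(valid$quality)), 3)
```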

8 Decision Tree via rpart


# rpart
set.seed(1)
rpart_model <- rpart(quality~alcohol+volatile_acidity+citric_acid+
                   density+pH+sulphates, train)

#fancyRpartPlot(rpart_model)

rpart_result <- predict(rpart_model, newdata = valid[,!colnames(valid) %in% c("quality")],type='class')
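The fancyRpartPlot() call is commented out because it needs the rattle package, which section 2.1 does not load. If you want the tree plot anyway, the rpart.plot package (my suggestion; also not loaded above) draws a similar annotated tree:

```r
# Alternative tree plot without rattle (assumes rpart.plot is installed):
pacman::p_load(rpart.plot)
rpart.plot(rpart_model, type = 2, extra = 104)  # class, per-class prob, % of obs
```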

confusionMatrix(valid$quality,rpart_result)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  3  4  5  6  7  8
##          3  0  0  1  0  0  0
##          4  0  0  2  3  0  0
##          5  0  0 50 18  0  0
##          6  0  0 21 38  4  0
##          7  0  0  1 13  5  0
##          8  0  0  0  1  0  0
## 
## Overall Statistics
##                                         
##                Accuracy : 0.5924        
##                  95% CI : (0.5112, 0.67)
##     No Information Rate : 0.4777        
##     P-Value [Acc > NIR] : 0.00257       
##                                         
##                   Kappa : 0.3201        
##  Mcnemar's Test P-Value : NA            
## 
## Statistics by Class:
## 
##                      Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
## Sensitivity                NA       NA   0.6667   0.5205  0.55556       NA
## Specificity          0.993631  0.96815   0.7805   0.7024  0.90541 0.993631
## Pos Pred Value             NA       NA   0.7353   0.6032  0.26316       NA
## Neg Pred Value             NA       NA   0.7191   0.6277  0.97101       NA
## Prevalence           0.000000  0.00000   0.4777   0.4650  0.05732 0.000000
## Detection Rate       0.000000  0.00000   0.3185   0.2420  0.03185 0.000000
## Detection Prevalence 0.006369  0.03185   0.4331   0.4013  0.12102 0.006369
## Balanced Accuracy          NA       NA   0.7236   0.6115  0.73048       NA
rm(rpart_model, rpart_result)

9 Random Forest


# randomforest
set.seed(1)
rf_model <- randomForest(quality~alcohol+volatile_acidity+citric_acid+
                           density+pH+sulphates,train)
rf_result <- predict(rf_model, newdata = valid[,!colnames(valid) %in% c("quality")])
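Random forest also gives a cheap cross-check of the variable screening done in section 5: the importance scores (sketch below) should rank alcohol, sulphates, and volatile_acidity near the top.

```r
# Mean decrease in Gini impurity per predictor -- higher = more useful.
importance(rf_model)
# varImpPlot(rf_model)  # the same information as a plot
```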

confusionMatrix(valid$quality,rf_result)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  3  4  5  6  7  8
##          3  0  0  0  1  0  0
##          4  0  0  4  1  0  0
##          5  0  0 54 14  0  0
##          6  0  0 17 40  6  0
##          7  0  0  1  6 11  1
##          8  0  0  0  0  0  1
## 
## Overall Statistics
##                                           
##                Accuracy : 0.6752          
##                  95% CI : (0.5959, 0.7476)
##     No Information Rate : 0.4841          
##     P-Value [Acc > NIR] : 1.028e-06       
##                                           
##                   Kappa : 0.475           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
## Sensitivity                NA       NA   0.7105   0.6452  0.64706 0.500000
## Specificity          0.993631  0.96815   0.8272   0.7579  0.94286 1.000000
## Pos Pred Value             NA       NA   0.7941   0.6349  0.57895 1.000000
## Neg Pred Value             NA       NA   0.7528   0.7660  0.95652 0.993590
## Prevalence           0.000000  0.00000   0.4841   0.3949  0.10828 0.012739
## Detection Rate       0.000000  0.00000   0.3439   0.2548  0.07006 0.006369
## Detection Prevalence 0.006369  0.03185   0.4331   0.4013  0.12102 0.006369
## Balanced Accuracy          NA       NA   0.7688   0.7015  0.79496 0.750000
rm(rf_model, rf_result)

10 SVM


# svm
set.seed(1)
svm_model <- svm(quality~alcohol+volatile_acidity+citric_acid+
                           density+pH+sulphates,train)
svm_result <- predict(svm_model, newdata = valid[,!colnames(valid) %in% c("quality")])
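The svm() fit above keeps e1071's default cost and gamma. A small grid search with e1071::tune() (a sketch; it is noticeably slower than a single fit) would pick them by cross-validation instead:

```r
# Grid search over cost/gamma via cross-validation (e1071 defaults).
svm_tuned <- tune(svm,
                  quality ~ alcohol + volatile_acidity + citric_acid +
                    density + pH + sulphates,
                  data = train,
                  ranges = list(cost  = c(0.5, 1, 2, 4),
                                gamma = c(0.1, 0.5, 1)))
svm_tuned$best.parameters
```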

confusionMatrix(valid$quality,svm_result)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  3  4  5  6  7  8
##          3  0  0  1  0  0  0
##          4  0  0  3  2  0  0
##          5  0  0 57 11  0  0
##          6  0  0 20 41  2  0
##          7  0  0  1 12  6  0
##          8  0  0  0  1  0  0
## 
## Overall Statistics
##                                           
##                Accuracy : 0.6624          
##                  95% CI : (0.5827, 0.7359)
##     No Information Rate : 0.5223          
##     P-Value [Acc > NIR] : 0.0002612       
##                                           
##                   Kappa : 0.4339          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
## Sensitivity                NA       NA   0.6951   0.6119  0.75000       NA
## Specificity          0.993631  0.96815   0.8533   0.7556  0.91275 0.993631
## Pos Pred Value             NA       NA   0.8382   0.6508  0.31579       NA
## Neg Pred Value             NA       NA   0.7191   0.7234  0.98551       NA
## Prevalence           0.000000  0.00000   0.5223   0.4268  0.05096 0.000000
## Detection Rate       0.000000  0.00000   0.3631   0.2611  0.03822 0.000000
## Detection Prevalence 0.006369  0.03185   0.4331   0.4013  0.12102 0.006369
## Balanced Accuracy          NA       NA   0.7742   0.6837  0.83138       NA
rm(svm_model, svm_result)

11 xgboost


# xgboost
data.train <- xgb.DMatrix(data = data.matrix(train[, !colnames(train) %in% c("quality")]), label = as.integer(train$quality))
# quality is a factor with levels 3..8, so the labels become integers 1..6;
# the predictions are mapped back to the original scale below via +2
data.valid <- xgb.DMatrix(data = data.matrix(valid[, !colnames(valid) %in% c("quality")]))


parameters <- list(
  # General Parameters
  booster            = "gbtree",          # default = "gbtree"
  silent             = 0,                 # default = 0
  # Booster Parameters
  eta                = 0.2,               # default = 0.3, range: [0,1]
  gamma              = 0,                 # default = 0,   range: [0,Inf]
  max_depth          = 5,                 # default = 6,   range: [1,Inf]
  min_child_weight   = 2,                 # default = 1,   range: [0,Inf]
  subsample          = 1,                 # default = 1,   range: (0,1]
  colsample_bytree   = 1,                 # default = 1,   range: (0,1]
  colsample_bylevel  = 1,                 # default = 1,   range: (0,1]
  lambda             = 1,                 # default = 1
  alpha              = 0,                 # default = 0
  # Task Parameters
  objective          = "multi:softmax",   # default = "reg:linear"
  eval_metric        = "merror",
  num_class          = 7,                 # labels are 1..6, plus the unused 0
  seed               = 1                  # reproducibility seed
)

xgb_model <- xgb.train(parameters, data.train, nrounds = 100)
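nrounds = 100 above is a guess; xgb.cv with early stopping (a sketch, adding a little runtime) is the standard way to pick it before the final xgb.train() call:

```r
# 5-fold CV with early stopping to choose nrounds; best_iteration holds
# the round with the lowest held-out merror.
xgb_cv <- xgb.cv(parameters, data.train, nrounds = 500, nfold = 5,
                 early_stopping_rounds = 20, verbose = 0)
xgb_cv$best_iteration
```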

xgb_pred <- predict(xgb_model, data.valid)

confusionMatrix(as.factor(xgb_pred+2), valid$quality)  # +2 maps classes 1..6 back to quality 3..8
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  3  4  5  6  7  8
##          3  0  0  0  0  0  0
##          4  0  0  1  1  0  0
##          5  0  4 55 12  2  0
##          6  1  1 12 45  5  0
##          7  0  0  0  4 11  0
##          8  0  0  0  1  1  1
## 
## Overall Statistics
##                                           
##                Accuracy : 0.7134          
##                  95% CI : (0.6359, 0.7826)
##     No Information Rate : 0.4331          
##     P-Value [Acc > NIR] : 1.144e-12       
##                                           
##                   Kappa : 0.5399          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
## Sensitivity          0.000000  0.00000   0.8088   0.7143  0.57895 1.000000
## Specificity          1.000000  0.98684   0.7978   0.7979  0.97101 0.987179
## Pos Pred Value            NaN  0.00000   0.7534   0.7031  0.73333 0.333333
## Neg Pred Value       0.993631  0.96774   0.8452   0.8065  0.94366 1.000000
## Prevalence           0.006369  0.03185   0.4331   0.4013  0.12102 0.006369
## Detection Rate       0.000000  0.00000   0.3503   0.2866  0.07006 0.006369
## Detection Prevalence 0.000000  0.01274   0.4650   0.4076  0.09554 0.019108
## Balanced Accuracy    0.500000  0.49342   0.8033   0.7561  0.77498 0.993590
rm(xgb_model, xgb_pred, data.train, data.valid, parameters)

12 h2o (deeplearning)


# h2o
h2o.init()
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         20 hours 54 minutes 
##     H2O cluster timezone:       America/New_York 
##     H2O data parsing timezone:  UTC 
##     H2O cluster version:        3.18.0.8 
##     H2O cluster version age:    24 days  
##     H2O cluster name:           H2O_started_from_R_Owen_slq629 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   1.51 GB 
##     H2O cluster total cores:    8 
##     H2O cluster allowed cores:  8 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     H2O Internal Security:      FALSE 
##     H2O API Extensions:         AutoML, Algos, Core V3, Core V4 
##     R Version:                  R version 3.5.0 (2018-04-23)
h2o.train <- as.h2o(train)
h2o.valid <- as.h2o(valid)
h2o.model <- h2o.deeplearning(x = setdiff(names(train), c("quality")),
                              y = "quality",
                              training_frame = h2o.train,
                              # activation = "RectifierWithDropout", # algorithm
                              # input_dropout_ratio = 0.2, # % of inputs dropout
                              # balance_classes = T,
                              # momentum_stable = 0.99,
                              # nesterov_accelerated_gradient = T, # use it for speed
                              epochs = 1000,
                              standardize = TRUE,         # standardize data
                              hidden = c(100, 100),       # 2 layers of 100 nodes each
                              rate = 0.05,                # learning rate
                              seed = 1                    # reproducibility seed
)
h2o.predictions <- h2o.predict(h2o.model, h2o.valid) %>% as.data.frame()
confusionMatrix(h2o.predictions$predict, valid$quality)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  3  4  5  6  7  8
##          3  0  0  0  0  0  0
##          4  1  3  2  1  0  0
##          5  0  1 52 11  4  0
##          6  0  0 13 40  1  0
##          7  0  1  1  8 13  0
##          8  0  0  0  3  1  1
## 
## Overall Statistics
##                                           
##                Accuracy : 0.6943          
##                  95% CI : (0.6158, 0.7652)
##     No Information Rate : 0.4331          
##     P-Value [Acc > NIR] : 3.523e-11       
##                                           
##                   Kappa : 0.5333          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
## Sensitivity          0.000000  0.60000   0.7647   0.6349   0.6842 1.000000
## Specificity          1.000000  0.97368   0.8202   0.8511   0.9275 0.974359
## Pos Pred Value            NaN  0.42857   0.7647   0.7407   0.5652 0.200000
## Neg Pred Value       0.993631  0.98667   0.8202   0.7767   0.9552 1.000000
## Prevalence           0.006369  0.03185   0.4331   0.4013   0.1210 0.006369
## Detection Rate       0.000000  0.01911   0.3312   0.2548   0.0828 0.006369
## Detection Prevalence 0.000000  0.04459   0.4331   0.3439   0.1465 0.031847
## Balanced Accuracy    0.500000  0.78684   0.7925   0.7430   0.8059 0.987179
rm(h2o.model, h2o.train, h2o.valid, h2o.predictions)

13 Conclusion


As I expected, xgboost gives the best outcome, but an ensemble model might potentially give an improved result.
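A simple way to try that ensemble is majority voting over the per-model predictions (a hypothetical helper, not part of the kernel above; the per-model result vectors were rm()'d in each section, so you would need to keep them around):

```r
# Hypothetical majority-vote combiner: given several prediction vectors
# (as characters) for the same rows, return the per-row most frequent
# label; ties go to the alphabetically first label via which.max().
majority_vote <- function(...) {
  preds <- cbind(...)                     # one column per model
  apply(preds, 1, function(row) names(which.max(table(row))))
}

# e.g. majority_vote(as.character(rf_result),
#                    as.character(svm_result),
#                    as.character(xgb_pred + 2))
```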